using Fantasy.SourceGenerator.Common;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using System.Collections.Generic;
using System.Linq;
using System.Text;

namespace Fantasy.SourceGenerator.Generators
{
    /// <summary>
    /// Message Dispatcher registration source generator.
    /// Emits the <c>Fantasy.Generated.MessageDispatcherRegistrar</c> class that MessageDispatcherComponent
    /// uses to register protocols and handlers, replacing runtime reflection.
    /// </summary>
    [Generator]
    public partial class MessageDispatcherGenerator : IIncrementalGenerator
    {
        public void Initialize(IncrementalGeneratorInitializationContext context)
        {
            // Find every class that implements one of the message-related interfaces.
            // The transform extracts only plain strings/bools (no symbols), so the
            // pipeline can cache unchanged entries between compilations.
            var messageTypes = context.SyntaxProvider
                .CreateSyntaxProvider(
                    predicate: static (node, _) => IsMessageRelatedClass(node),
                    transform: static (ctx, ct) => GetMessageTypeInfo(ctx, ct))
                .Where(static info => info != null)
                .Select(static (info, _) => info!)
                .Collect();

            // Reduce the compilation to the two values the output step needs. Combining with the raw
            // CompilationProvider would re-run code generation on every single edit.
            var generationSettings = context.CompilationProvider.Select(static (compilation, _) =>
            (
                // Check 1: FANTASY_NET or FANTASY_UNITY is defined.
                // Check 2: the Fantasy framework core interface is referenced.
                CompilationHelper.HasFantasyDefine(compilation) &&
                compilation.GetTypeByMetadataName("Fantasy.Assembly.IMessageDispatcherRegistrar") != null,
                compilation.AssemblyName ?? "Unknown"
            ));

            context.RegisterSourceOutput(generationSettings.Combine(messageTypes), static (spc, source) =>
            {
                var ((isEnabled, assemblyName), infos) = source;

                if (!isEnabled)
                {
                    return;
                }

                GenerateRegistrationCode(spc, assemblyName, infos);
            });
        }

        /// <summary>
        /// Extracts message-type information for a candidate class declaration.
        /// </summary>
        /// <param name="context">Syntax context whose node is a <see cref="ClassDeclarationSyntax"/> (guaranteed by the predicate).</param>
        /// <param name="cancellationToken">Token forwarded to the semantic model.</param>
        /// <returns>
        /// Info for the first interface in <c>AllInterfaces</c> order that is IMessage, IMessageHandler or
        /// IRouteMessageHandler; <c>null</c> when the class is not instantiable or implements none of them.
        /// </returns>
        private static MessageDispatcherInfo? GetMessageTypeInfo(GeneratorSyntaxContext context, System.Threading.CancellationToken cancellationToken)
        {
            var classDecl = (ClassDeclarationSyntax)context.Node;

            if (context.SemanticModel.GetDeclaredSymbol(classDecl, cancellationToken) is not INamedTypeSymbol symbol || !symbol.IsInstantiable())
            {
                return null;
            }

            foreach (var interfaceSymbol in symbol.AllInterfaces)
            {
                switch (interfaceSymbol.ToDisplayString())
                {
                    case GeneratorConstants.MessageInterfaces.IMessage:
                    {
                        // Only members declared directly on the type are inspected (same as before).
                        var responseProperty = symbol.GetMembers("ResponseType").OfType<IPropertySymbol>().FirstOrDefault();
                        var hasRouteType = symbol.GetMembers("RouteType").OfType<IPropertySymbol>().Any();
                        return new MessageDispatcherInfo(
                            MessageType.MessageProtocol,
                            symbol.GetFullName(),
                            symbol.Name,
                            responseProperty?.Type.GetFullName(),
                            hasRouteType);
                    }
                    case GeneratorConstants.MessageInterfaces.IMessageHandler:
                    {
                        return new MessageDispatcherInfo(MessageType.MessageHandler, symbol.GetFullName(), symbol.Name, null, false);
                    }
                    case GeneratorConstants.MessageInterfaces.IRouteMessageHandler:
                    {
                        return new MessageDispatcherInfo(MessageType.RouteMessageHandler, symbol.GetFullName(), symbol.Name, null, false);
                    }
                }
            }

            return null;
        }

        /// <summary>
        /// Emits MessageDispatcherRegistrar.g.cs for the collected message types.
        /// </summary>
        /// <param name="context">Output context the generated source is added to.</param>
        /// <param name="assemblyName">Assembly name used in the generated XML comment.</param>
        /// <param name="messageDispatcherInfos">All discovered message types (may contain one entry per partial declaration).</param>
        private static void GenerateRegistrationCode(
            SourceProductionContext context,
            string assemblyName,
            IEnumerable<MessageDispatcherInfo> messageDispatcherInfos)
        {
            var entries = CreateEntries(messageDispatcherInfos);
            var builder = new SourceCodeBuilder();
            // File header
            builder.AppendLine(GeneratorConstants.AutoGeneratedHeader);
            // Usings
            builder.AddUsings(
                "System",
                "System.Collections.Generic",
                "Fantasy.Assembly",
                "Fantasy.DataStructure.Dictionary",
                "Fantasy.Network.Interface"
            );
            builder.AppendLine();
            // The namespace is always Fantasy.Generated.
            builder.BeginNamespace("Fantasy.Generated");
            builder.AddXmlComment($"Auto-generated Message Dispatcher registration class for {assemblyName}");
            builder.BeginClass("MessageDispatcherRegistrar", "internal sealed", "IMessageDispatcherRegistrar");
            // Fields remembering what was registered, so UnRegisterSystems can undo it.
            GenerateFields(builder, entries);
            GenerateRegisterMethod(builder, entries);
            GenerateUnRegisterMethod(builder, entries);
            GenerateDisposeMethod(builder, entries);
            builder.EndClass();
            builder.EndNamespace();
            context.AddSource("MessageDispatcherRegistrar.g.cs", builder.ToString());
        }

        /// <summary>
        /// Deduplicates the discovered types and assigns each one a unique identifier suffix.
        /// Every partial declaration with a base list produces its own info for the same type, and
        /// different namespaces may reuse a simple type name; both would otherwise yield duplicate
        /// generated fields/locals and a compile error in the generated file.
        /// </summary>
        private static List<(MessageDispatcherInfo Info, string Id)> CreateEntries(IEnumerable<MessageDispatcherInfo> messageDispatcherInfos)
        {
            var entries = new List<(MessageDispatcherInfo Info, string Id)>();
            var seenTypes = new HashSet<string>();
            var usedIds = new HashSet<string>();

            foreach (var info in messageDispatcherInfos)
            {
                if (!seenTypes.Add(info.TypeFullName))
                {
                    continue;
                }

                var id = info.TypeName;
                for (var suffix = 1; !usedIds.Add(id); suffix++)
                {
                    id = $"{info.TypeName}_{suffix}";
                }

                entries.Add((info, id));
            }

            return entries;
        }

        /// <summary>
        /// Emits one field per entry holding the key it was registered under (opcode or handled Type).
        /// </summary>
        private static void GenerateFields(SourceCodeBuilder builder, List<(MessageDispatcherInfo Info, string Id)> entries)
        {
            builder.AddComment("Store registered instances for UnRegister");

            if (entries.Count == 0)
            {
                return;
            }

            foreach (var (info, id) in entries)
            {
                switch (info.MessageType)
                {
                    case MessageType.MessageProtocol:
                        builder.AppendLine($"private uint opCode_{id};");
                        break;
                    case MessageType.MessageHandler:
                        builder.AppendLine($"private Type messageType_{id};");
                        break;
                    case MessageType.RouteMessageHandler:
                        builder.AppendLine($"private Type routeMessageType_{id};");
                        break;
                }
            }

            builder.AppendLine();
        }

        /// <summary>
        /// Emits RegisterSystems. The NET and Unity variants share one body; anything that touches
        /// NET-only parameters (customRouteMap, routeMessageHandlers) goes inside #if FANTASY_NET.
        /// </summary>
        private static void GenerateRegisterMethod(SourceCodeBuilder builder, List<(MessageDispatcherInfo Info, string Id)> entries)
        {
            builder.AddXmlComment("Register all message systems to the dictionaries");
            builder.AppendLine("#if FANTASY_NET", false);
            builder.BeginMethod(
                "public void RegisterSystems(" +
                "DoubleMapDictionary<uint, Type> networkProtocols, " +
                "Dictionary<Type, Type> responseTypes, " +
                "Dictionary<Type, IMessageHandler> messageHandlers, " +
                "Dictionary<long, int> customRouteMap, " +
                "Dictionary<Type, IRouteMessageHandler> routeMessageHandlers)");
            builder.AppendLine("#endif", false);
            // Only one of the two signatures is active, so undo the first BeginMethod's indentation.
            builder.Unindent();
            builder.AppendLine("#if FANTASY_UNITY", false);
            builder.BeginMethod(
                "public void RegisterSystems(" +
                "DoubleMapDictionary<uint, Type> networkProtocols, " +
                "Dictionary<Type, Type> responseTypes, " +
                "Dictionary<Type, IMessageHandler> messageHandlers)");
            builder.AppendLine("#endif", false);

            if (entries.Count > 0)
            {
                var netOnlyBuilder = new SourceCodeBuilder();
                netOnlyBuilder.AppendLine("#if FANTASY_NET", false);
                netOnlyBuilder.Indent(3);

                foreach (var (info, id) in entries)
                {
                    // Prefixed local: cannot clash with the method parameters or C# keywords.
                    var local = $"instance_{id}";
                    switch (info.MessageType)
                    {
                        case MessageType.MessageProtocol:
                        {
                            builder.AppendLine($"var {local} = new {info.TypeFullName}();");
                            builder.AppendLine($"opCode_{id} = {local}.OpCode();");
                            builder.AppendLine($"networkProtocols.Add(opCode_{id}, typeof({info.TypeFullName}));");

                            if (info.ResponseTypeFullName != null)
                            {
                                builder.AppendLine($"responseTypes.Add(typeof({info.TypeFullName}), typeof({info.ResponseTypeFullName}));");
                            }

                            if (info.HasRouteType)
                            {
                                netOnlyBuilder.AppendLine($"customRouteMap[opCode_{id}] = {local}.RouteType;");
                            }
                            break;
                        }
                        case MessageType.MessageHandler:
                        {
                            builder.AppendLine($"var {local} = new {info.TypeFullName}();");
                            builder.AppendLine($"messageType_{id} = {local}.Type();");
                            builder.AppendLine($"messageHandlers.Add(messageType_{id}, {local});");
                            break;
                        }
                        case MessageType.RouteMessageHandler:
                        {
                            // routeMessageHandlers only exists in the FANTASY_NET signature.
                            netOnlyBuilder.AppendLine($"var {local} = new {info.TypeFullName}();");
                            netOnlyBuilder.AppendLine($"routeMessageType_{id} = {local}.Type();");
                            netOnlyBuilder.AppendLine($"routeMessageHandlers.Add(routeMessageType_{id}, {local});");
                            break;
                        }
                    }
                }

                netOnlyBuilder.Append("#endif");
                builder.AppendLine(netOnlyBuilder.ToString(), false);
            }

            builder.EndMethod();
            builder.AppendLine();
        }

        /// <summary>
        /// Emits UnRegisterSystems, removing exactly the keys RegisterSystems recorded in the fields.
        /// NET-only removals go inside #if FANTASY_NET, mirroring GenerateRegisterMethod.
        /// </summary>
        private static void GenerateUnRegisterMethod(SourceCodeBuilder builder, List<(MessageDispatcherInfo Info, string Id)> entries)
        {
            builder.AddXmlComment("Unregister all message systems from the dictionaries");
            builder.AppendLine("#if FANTASY_NET", false);
            builder.BeginMethod(
                "public void UnRegisterSystems(" +
                "DoubleMapDictionary<uint, Type> networkProtocols, " +
                "Dictionary<Type, Type> responseTypes, " +
                "Dictionary<Type, IMessageHandler> messageHandlers, " +
                "Dictionary<long, int> customRouteMap, " +
                "Dictionary<Type, IRouteMessageHandler> routeMessageHandlers)");
            builder.AppendLine("#endif", false);
            // Only one of the two signatures is active, so undo the first BeginMethod's indentation.
            builder.Unindent();
            builder.AppendLine("#if FANTASY_UNITY", false);
            builder.BeginMethod(
                "public void UnRegisterSystems(" +
                "DoubleMapDictionary<uint, Type> networkProtocols, " +
                "Dictionary<Type, Type> responseTypes, " +
                "Dictionary<Type, IMessageHandler> messageHandlers)");
            builder.AppendLine("#endif", false);

            if (entries.Count > 0)
            {
                var netOnlyBuilder = new SourceCodeBuilder();
                netOnlyBuilder.AppendLine("#if FANTASY_NET", false);
                netOnlyBuilder.Indent(3);

                foreach (var (info, id) in entries)
                {
                    switch (info.MessageType)
                    {
                        case MessageType.MessageProtocol:
                        {
                            builder.AppendLine($"networkProtocols.RemoveByKey(opCode_{id});");

                            if (info.ResponseTypeFullName != null)
                            {
                                builder.AppendLine($"responseTypes.Remove(typeof({info.TypeFullName}));");
                            }

                            if (info.HasRouteType)
                            {
                                netOnlyBuilder.AppendLine($"customRouteMap.Remove(opCode_{id});");
                            }
                            break;
                        }
                        case MessageType.MessageHandler:
                        {
                            builder.AppendLine($"messageHandlers.Remove(messageType_{id});");
                            break;
                        }
                        case MessageType.RouteMessageHandler:
                        {
                            // routeMessageHandlers only exists in the FANTASY_NET signature.
                            netOnlyBuilder.AppendLine($"routeMessageHandlers.Remove(routeMessageType_{id});");
                            break;
                        }
                    }
                }

                netOnlyBuilder.Append("#endif");
                builder.AppendLine(netOnlyBuilder.ToString(), false);
            }

            builder.EndMethod();
            builder.AppendLine();
        }

        /// <summary>
        /// Emits Dispose, resetting every generated field to its default value.
        /// </summary>
        private static void GenerateDisposeMethod(SourceCodeBuilder builder, List<(MessageDispatcherInfo Info, string Id)> entries)
        {
            builder.AddXmlComment("Dispose all resources");
            builder.BeginMethod("public void Dispose()");
            builder.AddComment("Clear all references");

            foreach (var (info, id) in entries)
            {
                switch (info.MessageType)
                {
                    case MessageType.MessageProtocol:
                        builder.AppendLine($"opCode_{id} = 0;");
                        break;
                    case MessageType.MessageHandler:
                        builder.AppendLine($"messageType_{id} = null;");
                        break;
                    case MessageType.RouteMessageHandler:
                        builder.AppendLine($"routeMessageType_{id} = null;");
                        break;
                }
            }

            builder.EndMethod();
        }

        /// <summary>
        /// Cheap syntactic filter: a class declaration with at least one base type/interface listed.
        /// </summary>
        private static bool IsMessageRelatedClass(SyntaxNode node)
        {
            if (node is not ClassDeclarationSyntax classDecl)
            {
                return false;
            }

            return classDecl.BaseList != null && classDecl.BaseList.Types.Count > 0;
        }

        private enum MessageType
        {
            None,
            MessageProtocol,        // IMessage protocol class
            MessageHandler,         // Message<T> / MessageRPC<T,R> handler
            RouteMessageHandler     // Route<E,M> / RouteRPC<E,R,S> and other route handlers
        }

        /// <summary>
        /// Immutable, value-equatable description of one discovered type. It deliberately holds no
        /// Roslyn symbols, so the incremental pipeline can compare and cache it across compilations.
        /// </summary>
        private sealed class MessageDispatcherInfo : System.IEquatable<MessageDispatcherInfo>
        {
            public readonly MessageType MessageType;
            public readonly string TypeFullName;
            public readonly string TypeName;
            // Full name of the type of the class's own "ResponseType" property, or null if it has none.
            public readonly string? ResponseTypeFullName;
            // True when the class declares its own "RouteType" property.
            public readonly bool HasRouteType;

            public MessageDispatcherInfo(MessageType messageType, string typeFullName, string typeName, string? responseTypeFullName, bool hasRouteType)
            {
                MessageType = messageType;
                TypeFullName = typeFullName;
                TypeName = typeName;
                ResponseTypeFullName = responseTypeFullName;
                HasRouteType = hasRouteType;
            }

            public bool Equals(MessageDispatcherInfo? other)
            {
                return other is not null &&
                       MessageType == other.MessageType &&
                       TypeFullName == other.TypeFullName &&
                       TypeName == other.TypeName &&
                       ResponseTypeFullName == other.ResponseTypeFullName &&
                       HasRouteType == other.HasRouteType;
            }

            public override bool Equals(object? obj) => Equals(obj as MessageDispatcherInfo);

            public override int GetHashCode()
            {
                unchecked
                {
                    var hash = (int)MessageType;
                    hash = (hash * 397) ^ TypeFullName.GetHashCode();
                    hash = (hash * 397) ^ TypeName.GetHashCode();
                    hash = (hash * 397) ^ (ResponseTypeFullName?.GetHashCode() ?? 0);
                    hash = (hash * 397) ^ (HasRouteType ? 1 : 0);
                    return hash;
                }
            }
        }
    }
}
